# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for recurrent layers functionality other than GRU, LSTM, SimpleRNN.

See also: lstm_test.py, gru_test.py, simplernn_test.py.
"""


import collections

import numpy as np
import tensorflow.compat.v2 as tf
from absl.testing import parameterized

import keras
from keras.engine import base_layer_utils
from keras.layers.rnn import gru
from keras.layers.rnn import gru_v1
from keras.layers.rnn import lstm
from keras.layers.rnn import lstm_v1
from keras.testing_infra import test_combinations
from keras.testing_infra import test_utils
from keras.utils import generic_utils

# isort: off
from tensorflow.python.checkpoint import (
    checkpoint as trackable_util,
)

# Used for nested input/output/state RNN test.
NestedInput = collections.namedtuple("NestedInput", ["t1", "t2"])
NestedState = collections.namedtuple("NestedState", ["s1", "s2"])


@test_combinations.run_all_keras_modes
class RNNTest(test_combinations.TestCase):
    def test_minimal_rnn_cell_non_layer(self):
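        # An RNN cell need not be a Layer: any object exposing
        # `call(inputs, states)` and a `state_size` attribute works.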
        class MinimalRNNCell:
            def __init__(self, units, input_dim):
                self.units = units
                self.state_size = units
                self.kernel = keras.backend.variable(
                    np.random.random((input_dim, units))
                )

            def call(self, inputs, states):
                prev_output = states[0]
                output = keras.backend.dot(inputs, self.kernel) + prev_output
                return output, [output]

        # Basic test case.
        cell = MinimalRNNCell(32, 5)
        x = keras.Input((None, 5))
        layer = keras.layers.RNN(cell)
        y = layer(x)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

        # Test stacking.
        cells = [
            MinimalRNNCell(8, 5),
            MinimalRNNCell(32, 8),
            MinimalRNNCell(32, 32),
        ]
        layer = keras.layers.RNN(cells)
        y = layer(x)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

    def test_minimal_rnn_cell_non_layer_multiple_states(self):
        class MinimalRNNCell:
            def __init__(self, units, input_dim):
                self.units = units
                self.state_size = (units, units)
                self.kernel = keras.backend.variable(
                    np.random.random((input_dim, units))
                )

            def call(self, inputs, states):
                prev_output_1 = states[0]
                prev_output_2 = states[1]
                output = keras.backend.dot(inputs, self.kernel)
                output += prev_output_1
                output -= prev_output_2
                return output, [output * 2, output * 3]

        # Basic test case.
        cell = MinimalRNNCell(32, 5)
        x = keras.Input((None, 5))
        layer = keras.layers.RNN(cell)
        y = layer(x)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

        # Test stacking.
        cells = [
            MinimalRNNCell(8, 5),
            MinimalRNNCell(16, 8),
            MinimalRNNCell(32, 16),
        ]
        layer = keras.layers.RNN(cells)
        self.assertEqual(layer.cell.state_size, ((8, 8), (16, 16), (32, 32)))
        self.assertEqual(layer.cell.output_size, 32)
        y = layer(x)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

    def test_minimal_rnn_cell_layer(self):
        class MinimalRNNCell(keras.layers.Layer):
            def __init__(self, units, **kwargs):
                self.units = units
                self.state_size = units
                super().__init__(**kwargs)

            def build(self, input_shape):
                self.kernel = self.add_weight(
                    shape=(input_shape[-1], self.units),
                    initializer="uniform",
                    name="kernel",
                )
                self.recurrent_kernel = self.add_weight(
                    shape=(self.units, self.units),
                    initializer="uniform",
                    name="recurrent_kernel",
                )
                self.built = True

            def call(self, inputs, states):
                prev_output = states[0]
                h = keras.backend.dot(inputs, self.kernel)
                output = h + keras.backend.dot(
                    prev_output, self.recurrent_kernel
                )
                return output, [output]

            def get_config(self):
                config = {"units": self.units}
                base_config = super().get_config()
                return dict(list(base_config.items()) + list(config.items()))

        # Test basic case.
        x = keras.Input((None, 5))
        cell = MinimalRNNCell(32)
        layer = keras.layers.RNN(cell)
        y = layer(x)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

        # Test basic case serialization.
        x_np = np.random.random((6, 5, 5))
        y_np = model.predict(x_np)
        weights = model.get_weights()
        config = layer.get_config()
        with generic_utils.CustomObjectScope(
            {"MinimalRNNCell": MinimalRNNCell}
        ):
            layer = keras.layers.RNN.from_config(config)
        y = layer(x)
        model = keras.models.Model(x, y)
        model.set_weights(weights)
        y_np_2 = model.predict(x_np)
        self.assertAllClose(y_np, y_np_2, atol=1e-4)

        # Test stacking.
        cells = [MinimalRNNCell(8), MinimalRNNCell(12), MinimalRNNCell(32)]
        layer = keras.layers.RNN(cells)
        y = layer(x)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

        # Test stacked RNN serialization.
        x_np = np.random.random((6, 5, 5))
        y_np = model.predict(x_np)
        weights = model.get_weights()
        config = layer.get_config()
        with generic_utils.CustomObjectScope(
            {"MinimalRNNCell": MinimalRNNCell}
        ):
            layer = keras.layers.RNN.from_config(config)
        y = layer(x)
        model = keras.models.Model(x, y)
        model.set_weights(weights)
        y_np_2 = model.predict(x_np)
        self.assertAllClose(y_np, y_np_2, atol=1e-4)

    def test_minimal_rnn_cell_abstract_rnn_cell(self):
        class MinimalRNNCell(keras.layers.AbstractRNNCell):
            def __init__(self, units, **kwargs):
                self.units = units
                super().__init__(**kwargs)

            @property
            def state_size(self):
                return self.units

            def build(self, input_shape):
                self.kernel = self.add_weight(
                    shape=(input_shape[-1], self.units),
                    initializer="uniform",
                    name="kernel",
                )
                self.recurrent_kernel = self.add_weight(
                    shape=(self.units, self.units),
                    initializer="uniform",
                    name="recurrent_kernel",
                )
                self.built = True

            def call(self, inputs, states):
                prev_output = states[0]
                h = keras.backend.dot(inputs, self.kernel)
                output = h + keras.backend.dot(
                    prev_output, self.recurrent_kernel
                )
                return output, output

            @property
            def output_size(self):
                return self.units

        cell = MinimalRNNCell(32)
        x = keras.Input((None, 5))
        layer = keras.layers.RNN(cell)
        y = layer(x)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

        # Test stacking.
        cells = [MinimalRNNCell(8), MinimalRNNCell(16), MinimalRNNCell(32)]
        layer = keras.layers.RNN(cells)
        y = layer(x)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

    def test_rnn_with_time_major(self):
        batch = 10
        time_step = 5
        embedding_dim = 4
        units = 3

        # Test basic case.
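        # With time_major=True the layer consumes and produces tensors shaped
        # [timesteps, batch, ...] rather than [batch, timesteps, ...], so the
        # batch-major input is transposed on the way in and back afterwards.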
        x = keras.Input((time_step, embedding_dim))
        time_major_x = keras.layers.Lambda(
            lambda t: tf.transpose(t, [1, 0, 2])
        )(x)
        layer = keras.layers.SimpleRNN(
            units, time_major=True, return_sequences=True
        )
        self.assertEqual(
            layer.compute_output_shape(
                (time_step, None, embedding_dim)
            ).as_list(),
            [time_step, None, units],
        )
        y = layer(time_major_x)
        self.assertEqual(layer.output_shape, (time_step, None, units))

        y = keras.layers.Lambda(lambda t: tf.transpose(t, [1, 0, 2]))(y)

        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            np.zeros((batch, time_step, embedding_dim)),
            np.zeros((batch, time_step, units)),
        )

        # Test stacking.
        x = keras.Input((time_step, embedding_dim))
        time_major_x = keras.layers.Lambda(
            lambda t: tf.transpose(t, [1, 0, 2])
        )(x)
        cell_units = [10, 8, 6]
        cells = [keras.layers.SimpleRNNCell(cell_units[i]) for i in range(3)]
        layer = keras.layers.RNN(cells, time_major=True, return_sequences=True)
        y = layer(time_major_x)
        self.assertEqual(layer.output_shape, (time_step, None, cell_units[-1]))

        y = keras.layers.Lambda(lambda t: tf.transpose(t, [1, 0, 2]))(y)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            np.zeros((batch, time_step, embedding_dim)),
            np.zeros((batch, time_step, cell_units[-1])),
        )

        # Test masking.
        x = keras.Input((time_step, embedding_dim))
        time_major = keras.layers.Lambda(lambda t: tf.transpose(t, [1, 0, 2]))(
            x
        )
        mask = keras.layers.Masking()(time_major)
        rnn = keras.layers.SimpleRNN(
            units, time_major=True, return_sequences=True
        )(mask)
        y = keras.layers.Lambda(lambda t: tf.transpose(t, [1, 0, 2]))(rnn)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            np.zeros((batch, time_step, embedding_dim)),
            np.zeros((batch, time_step, units)),
        )

        # Test layer output
        x = keras.Input((time_step, embedding_dim))
        rnn_1 = keras.layers.SimpleRNN(units, return_sequences=True)
        y = rnn_1(x)

        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            np.zeros((batch, time_step, embedding_dim)),
            np.zeros((batch, time_step, units)),
        )

        x_np = np.random.random((batch, time_step, embedding_dim))
        y_np_1 = model.predict(x_np)

        time_major = keras.layers.Lambda(lambda t: tf.transpose(t, [1, 0, 2]))(
            x
        )
        rnn_2 = keras.layers.SimpleRNN(
            units, time_major=True, return_sequences=True
        )
        y_2 = rnn_2(time_major)
        y_2 = keras.layers.Lambda(lambda t: tf.transpose(t, [1, 0, 2]))(y_2)

        model_2 = keras.models.Model(x, y_2)
        rnn_2.set_weights(rnn_1.get_weights())

        y_np_2 = model_2.predict(x_np)
        self.assertAllClose(y_np_1, y_np_2, atol=1e-4)

    def test_rnn_cell_with_constants_layer(self):
        # Test basic case.
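        # `constants` are extra tensors passed unchanged to the cell at every
        # timestep (e.g. an attention context).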
        x = keras.Input((None, 5))
        c = keras.Input((3,))
        cell = RNNCellWithConstants(32, constant_size=3)
        layer = keras.layers.RNN(cell)
        y = layer(x, constants=c)

        model = keras.models.Model([x, c], y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            [np.zeros((6, 5, 5)), np.zeros((6, 3))], np.zeros((6, 32))
        )

        # Test basic case serialization.
        x_np = np.random.random((6, 5, 5))
        c_np = np.random.random((6, 3))
        y_np = model.predict([x_np, c_np])
        weights = model.get_weights()
        config = layer.get_config()
        custom_objects = {"RNNCellWithConstants": RNNCellWithConstants}
        with generic_utils.CustomObjectScope(custom_objects):
            layer = keras.layers.RNN.from_config(config.copy())
        y = layer(x, constants=c)
        model = keras.models.Model([x, c], y)
        model.set_weights(weights)
        y_np_2 = model.predict([x_np, c_np])
        self.assertAllClose(y_np, y_np_2, atol=1e-4)

        # test flat list inputs.
        with generic_utils.CustomObjectScope(custom_objects):
            layer = keras.layers.RNN.from_config(config.copy())
        y = layer([x, c])
        model = keras.models.Model([x, c], y)
        model.set_weights(weights)
        y_np_3 = model.predict([x_np, c_np])
        self.assertAllClose(y_np, y_np_3, atol=1e-4)

        # Test stacking.
        cells = [
            gru.GRUCell(8),
            RNNCellWithConstants(12, constant_size=3),
            RNNCellWithConstants(32, constant_size=3),
        ]
        layer = keras.layers.RNN(cells)
        y = layer(x, constants=c)
        model = keras.models.Model([x, c], y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            [np.zeros((6, 5, 5)), np.zeros((6, 3))], np.zeros((6, 32))
        )

        # Test GRUCell reset_after property.
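        # reset_after=True applies the reset gate after (rather than before)
        # the recurrent matrix multiplication, matching the cuDNN-compatible
        # GRU formulation.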
        x = keras.Input((None, 5))
        c = keras.Input((3,))
        cells = [gru.GRUCell(32, reset_after=True)]
        layer = keras.layers.RNN(cells)
        y = layer(x, constants=c)
        model = keras.models.Model([x, c], y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            [np.zeros((6, 5, 5)), np.zeros((6, 3))], np.zeros((6, 32))
        )

        # Test stacked RNN serialization
        x_np = np.random.random((6, 5, 5))
        c_np = np.random.random((6, 3))
        y_np = model.predict([x_np, c_np])
        weights = model.get_weights()
        config = layer.get_config()
        with generic_utils.CustomObjectScope(custom_objects):
            layer = keras.layers.RNN.from_config(config.copy())
        y = layer(x, constants=c)
        model = keras.models.Model([x, c], y)
        model.set_weights(weights)
        y_np_2 = model.predict([x_np, c_np])
        self.assertAllClose(y_np, y_np_2, atol=1e-4)

    def test_rnn_cell_with_non_keras_constants(self):
        # Test basic case.
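        # Here the constant is a plain tensor rather than a Keras Input, so it
        # is captured by the layer instead of becoming a model input.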
        x = keras.Input((None, 5))
        c = tf.zeros([6, 3], dtype=tf.float32)
        cell = RNNCellWithConstants(32, constant_size=3)
        layer = keras.layers.RNN(cell)
        y = layer(x, constants=c)

        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

        # Test stacking.
        cells = [
            gru.GRUCell(8),
            RNNCellWithConstants(12, constant_size=3),
            RNNCellWithConstants(32, constant_size=3),
        ]
        layer = keras.layers.RNN(cells)
        y = layer(x, constants=c)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

    def test_rnn_cell_with_constants_layer_passing_initial_state(self):
        # Test basic case.
        x = keras.Input((None, 5))
        c = keras.Input((3,))
        s = keras.Input((32,))
        cell = RNNCellWithConstants(32, constant_size=3)
        layer = keras.layers.RNN(cell)
        y = layer(x, initial_state=s, constants=c)
        model = keras.models.Model([x, s, c], y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            [np.zeros((6, 5, 5)), np.zeros((6, 32)), np.zeros((6, 3))],
            np.zeros((6, 32)),
        )

        # Test basic case serialization.
        x_np = np.random.random((6, 5, 5))
        s_np = np.random.random((6, 32))
        c_np = np.random.random((6, 3))
        y_np = model.predict([x_np, s_np, c_np])
        weights = model.get_weights()
        config = layer.get_config()
        custom_objects = {"RNNCellWithConstants": RNNCellWithConstants}
        with generic_utils.CustomObjectScope(custom_objects):
            layer = keras.layers.RNN.from_config(config.copy())
        y = layer(x, initial_state=s, constants=c)
        model = keras.models.Model([x, s, c], y)
        model.set_weights(weights)
        y_np_2 = model.predict([x_np, s_np, c_np])
        self.assertAllClose(y_np, y_np_2, atol=1e-4)

        # Verify that the initial state actually affects the output.
        y_np_2_different_s = model.predict([x_np, s_np + 10.0, c_np])
        with self.assertRaises(AssertionError):
            self.assertAllClose(y_np, y_np_2_different_s, atol=1e-4)

        # test flat list inputs
        with generic_utils.CustomObjectScope(custom_objects):
            layer = keras.layers.RNN.from_config(config.copy())
        y = layer([x, s, c])
        model = keras.models.Model([x, s, c], y)
        model.set_weights(weights)
        y_np_3 = model.predict([x_np, s_np, c_np])
        self.assertAllClose(y_np, y_np_3, atol=1e-4)

    def test_rnn_cell_with_non_keras_constants_and_initial_state(self):
        # Test basic case.
        x = keras.Input((None, 5))
        c = tf.zeros([6, 3], dtype=tf.float32)
        s = tf.zeros([6, 32], dtype=tf.float32)
        cell = RNNCellWithConstants(32, constant_size=3)
        layer = keras.layers.RNN(cell)
        y = layer(x, initial_state=s, constants=c)

        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

        # Test stacking.
        cells = [
            gru.GRUCell(8),
            RNNCellWithConstants(12, constant_size=3),
            RNNCellWithConstants(32, constant_size=3),
        ]
        layer = keras.layers.RNN(cells)
        s = [
            tf.zeros([6, 8], dtype=tf.float32),
            tf.zeros([6, 12], dtype=tf.float32),
            tf.zeros([6, 32], dtype=tf.float32),
        ]
        y = layer(x, initial_state=s, constants=c)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

    def test_stacked_rnn_attributes(self):
        if tf.executing_eagerly():
            self.skipTest("reduce_sum is not available in eager mode.")

        cells = [keras.layers.LSTMCell(1), keras.layers.LSTMCell(1)]
        layer = keras.layers.RNN(cells)
        layer.build((None, None, 1))

        # Test weights
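        # Each LSTMCell holds 3 weights (kernel, recurrent_kernel, bias), so
        # two stacked cells expose 6 trainable weights.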
        self.assertEqual(len(layer.trainable_weights), 6)
        cells[0].trainable = False
        self.assertEqual(len(layer.trainable_weights), 3)
        self.assertEqual(len(layer.non_trainable_weights), 3)

        # Test `get_losses_for` and `losses`
        x = keras.Input((None, 1))
        loss_1 = tf.reduce_sum(x)
        loss_2 = tf.reduce_sum(cells[0].kernel)
        cells[0].add_loss(loss_1, inputs=x)
        cells[0].add_loss(loss_2)
        self.assertEqual(len(layer.losses), 2)
        self.assertEqual(layer.get_losses_for(None), [loss_2])
        self.assertEqual(layer.get_losses_for(x), [loss_1])

        # Test `updates`
        cells = [keras.layers.LSTMCell(1), keras.layers.LSTMCell(1)]
        layer = keras.layers.RNN(cells)
        x = keras.Input((None, 1))
        _ = layer(x)

        update_1 = tf.compat.v1.assign_add(
            cells[0].kernel, x[0, 0, 0] * cells[0].kernel
        )
        update_2 = tf.compat.v1.assign_add(
            cells[0].kernel, tf.ones_like(cells[0].kernel)
        )
        # TODO(b/128682878): Remove when RNNCells are __call__'d.
        with base_layer_utils.call_context().enter(layer, x, True, None):
            cells[0].add_update(update_1)
            cells[0].add_update(update_2)
        self.assertEqual(len(layer.updates), 2)

    def test_rnn_dynamic_trainability(self):
        layer_class = keras.layers.SimpleRNN
        embedding_dim = 4
        units = 3

        layer = layer_class(units)
        layer.build((None, None, embedding_dim))
        self.assertEqual(len(layer.weights), 3)
        self.assertEqual(len(layer.trainable_weights), 3)
        self.assertEqual(len(layer.non_trainable_weights), 0)
        layer.trainable = False
        self.assertEqual(len(layer.weights), 3)
        self.assertEqual(len(layer.trainable_weights), 0)
        self.assertEqual(len(layer.non_trainable_weights), 3)
        layer.trainable = True
        self.assertEqual(len(layer.weights), 3)
        self.assertEqual(len(layer.trainable_weights), 3)
        self.assertEqual(len(layer.non_trainable_weights), 0)

    @parameterized.parameters(
        [keras.layers.SimpleRNN, keras.layers.GRU, keras.layers.LSTM]
    )
    def test_rnn_cell_trainability(self, layer_cls):
        # https://github.com/tensorflow/tensorflow/issues/32369.
        layer = layer_cls(3, trainable=False)
        self.assertFalse(layer.cell.trainable)

        layer.trainable = True
        self.assertTrue(layer.cell.trainable)

    def test_state_reuse_with_dropout(self):
        layer_class = keras.layers.SimpleRNN
        embedding_dim = 4
        units = 3
        timesteps = 2
        num_samples = 2

        input1 = keras.Input(
            batch_shape=(num_samples, timesteps, embedding_dim)
        )
        layer = layer_class(
            units, return_state=True, return_sequences=True, dropout=0.2
        )
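        # Keep only the returned states (index 0 is the output sequence) and
        # reuse them as the initial state of a second layer; dropout must not
        # break this symbolic state reuse.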
        state = layer(input1)[1:]

        input2 = keras.Input(
            batch_shape=(num_samples, timesteps, embedding_dim)
        )
        output = layer_class(units)(input2, initial_state=state)
        model = keras.Model([input1, input2], output)

        inputs = [
            np.random.random((num_samples, timesteps, embedding_dim)),
            np.random.random((num_samples, timesteps, embedding_dim)),
        ]
        model.predict(inputs)

    def test_builtin_and_custom_rnn_cell_serialization(self):
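        # Registering the custom cell makes it deserializable via from_config
        # without wrapping the call in a CustomObjectScope.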
        @keras.utils.generic_utils.register_keras_serializable(
            package="TestOnly"
        )
        class CustomRNNCell(keras.layers.Layer):
            def __init__(self, units, **kwargs):
                self.units = units
                self.state_size = units
                super().__init__(**kwargs)

            def build(self, input_shape):
                self.kernel = self.add_weight(
                    shape=(input_shape[-1], self.units),
                    initializer="uniform",
                    name="kernel",
                )
                self.recurrent_kernel = self.add_weight(
                    shape=(self.units, self.units),
                    initializer="uniform",
                    name="recurrent_kernel",
                )
                self.built = True

            def call(self, inputs, states):
                prev_output = states[0]
                h = keras.backend.dot(inputs, self.kernel)
                output = h + keras.backend.dot(
                    prev_output, self.recurrent_kernel
                )
                return output, [output]

            def get_config(self):
                config = {"units": self.units}
                base_config = super().get_config()
                return dict(list(base_config.items()) + list(config.items()))

        for cell_class in [
            keras.layers.SimpleRNNCell,
            keras.layers.GRUCell,
            keras.layers.LSTMCell,
            CustomRNNCell,
        ]:
            # Test basic case.
            x = keras.Input((None, 5))
            cell = cell_class(32)
            layer = keras.layers.RNN(cell)
            y = layer(x)
            model = keras.models.Model(x, y)
            model.compile(
                optimizer="rmsprop",
                loss="mse",
                run_eagerly=test_utils.should_run_eagerly(),
            )

            # Test basic case serialization.
            x_np = np.random.random((6, 5, 5))
            y_np = model.predict(x_np)
            weights = model.get_weights()
            config = layer.get_config()
            layer = keras.layers.RNN.from_config(config)
            y = layer(x)
            model = keras.models.Model(x, y)
            model.set_weights(weights)
            y_np_2 = model.predict(x_np)
            self.assertAllClose(y_np, y_np_2, atol=1e-4)

            # Test stacking.
            cells = [cell_class(8), cell_class(12), cell_class(32)]
            layer = keras.layers.RNN(cells)
            y = layer(x)
            model = keras.models.Model(x, y)
            model.compile(
                optimizer="rmsprop",
                loss="mse",
                run_eagerly=test_utils.should_run_eagerly(),
            )

            # Test stacked RNN serialization.
            x_np = np.random.random((6, 5, 5))
            y_np = model.predict(x_np)
            weights = model.get_weights()
            config = layer.get_config()
            layer = keras.layers.RNN.from_config(config)
            y = layer(x)
            model = keras.models.Model(x, y)
            model.set_weights(weights)
            y_np_2 = model.predict(x_np)
            self.assertAllClose(y_np, y_np_2, atol=1e-4)

    @parameterized.named_parameters(
        *test_utils.generate_combinations_with_testcase_name(
            layer=[
                keras.layers.SimpleRNN,
                gru_v1.GRU,
                lstm_v1.LSTM,
                gru.GRU,
                lstm.LSTM,
            ],
            unroll=[True, False],
        )
    )
    def test_rnn_dropout(self, layer, unroll):
        rnn_layer = layer(3, dropout=0.1, recurrent_dropout=0.1, unroll=unroll)
        if not unroll:
            x = keras.Input((None, 5))
        else:
            x = keras.Input((5, 5))
        y = rnn_layer(x)
        model = keras.models.Model(x, y)
        model.compile("sgd", "mse", run_eagerly=test_utils.should_run_eagerly())
        x_np = np.random.random((6, 5, 5))
        y_np = np.random.random((6, 3))
        model.train_on_batch(x_np, y_np)

    @parameterized.named_parameters(
        *test_utils.generate_combinations_with_testcase_name(
            cell=[
                keras.layers.SimpleRNNCell,
                keras.layers.GRUCell,
                keras.layers.LSTMCell,
            ],
            unroll=[True, False],
        )
    )
    def test_stacked_rnn_dropout(self, cell, unroll):
        cells = [
            cell(3, dropout=0.1, recurrent_dropout=0.1),
            cell(3, dropout=0.1, recurrent_dropout=0.1),
        ]
        layer = keras.layers.RNN(cells, unroll=unroll)

        if not unroll:
            x = keras.Input((None, 5))
        else:
            x = keras.Input((5, 5))
        y = layer(x)
        model = keras.models.Model(x, y)
        model.compile("sgd", "mse", run_eagerly=test_utils.should_run_eagerly())
        x_np = np.random.random((6, 5, 5))
        y_np = np.random.random((6, 3))
        model.train_on_batch(x_np, y_np)

    def test_dropout_mask_reuse(self):
        # The layer is created with recurrent_initializer="zeros" so that the
        # recurrent state won't affect the output. This lets us inspect the
        # output and verify that the same dropout mask is applied at every
        # timestep.
        layer_1 = keras.layers.SimpleRNN(
            3,
            dropout=0.5,
            kernel_initializer="ones",
            recurrent_initializer="zeros",
            return_sequences=True,
            unroll=True,
        )
        layer_2 = keras.layers.RNN(
            keras.layers.SimpleRNNCell(
                3,
                dropout=0.5,
                kernel_initializer="ones",
                recurrent_initializer="zeros",
            ),
            return_sequences=True,
            unroll=True,
        )
        layer_3 = keras.layers.RNN(
            [
                keras.layers.SimpleRNNCell(
                    3,
                    dropout=0.5,
                    kernel_initializer="ones",
                    recurrent_initializer="zeros",
                ),
                keras.layers.SimpleRNNCell(
                    3,
                    dropout=0.5,
                    kernel_initializer="ones",
                    recurrent_initializer="zeros",
                ),
            ],
            return_sequences=True,
            unroll=True,
        )

        def verify(rnn_layer):
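            # With recurrent_kernel == 0 and a constant all-ones input, each
            # timestep's output depends only on the dropped-out input, so
            # identical outputs across timesteps imply the same dropout mask
            # was reused at every step.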
            inputs = tf.constant(1.0, shape=(6, 2, 5))
            out = rnn_layer(inputs, training=True)
            if not tf.executing_eagerly():
                self.evaluate(tf.compat.v1.global_variables_initializer())
            batch_1 = self.evaluate(out)
            batch_1_t0, batch_1_t1 = batch_1[:, 0, :], batch_1[:, 1, :]
            self.assertAllClose(batch_1_t0, batch_1_t1)

            # This simulates the layer being called with multiple batches in
            # eager mode.
            if tf.executing_eagerly():
                out2 = rnn_layer(inputs, training=True)
            else:
                out2 = out
            batch_2 = self.evaluate(out2)
            batch_2_t0, batch_2_t1 = batch_2[:, 0, :], batch_2[:, 1, :]
            self.assertAllClose(batch_2_t0, batch_2_t1)

            # Also validate that a different dropout mask is used between
            # batches.
            self.assertNotAllClose(batch_1_t0, batch_2_t0)
            self.assertNotAllClose(batch_1_t1, batch_2_t1)

        for l in [layer_1, layer_2, layer_3]:
            verify(l)

    def test_stacked_rnn_compute_output_shape(self):
        cells = [keras.layers.LSTMCell(3), keras.layers.LSTMCell(6)]
        embedding_dim = 4
        timesteps = 2
        layer = keras.layers.RNN(
            cells, return_state=True, return_sequences=True
        )
        output_shape = layer.compute_output_shape(
            (None, timesteps, embedding_dim)
        )
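        # Expect the full sequence from the last cell, followed by the (h, c)
        # state pair of each cell in stacking order.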
        expected_output_shape = [
            (None, timesteps, 6),
            (None, 3),
            (None, 3),
            (None, 6),
            (None, 6),
        ]
        self.assertEqual(
            [tuple(o.as_list()) for o in output_shape], expected_output_shape
        )

        # Test reverse_state_order = True for stacked cell.
        stacked_cell = keras.layers.StackedRNNCells(
            cells, reverse_state_order=True
        )
        layer = keras.layers.RNN(
            stacked_cell, return_state=True, return_sequences=True
        )
        output_shape = layer.compute_output_shape(
            (None, timesteps, embedding_dim)
        )
        expected_output_shape = [
            (None, timesteps, 6),
            (None, 6),
            (None, 6),
            (None, 3),
            (None, 3),
        ]
        self.assertEqual(
            [tuple(o.as_list()) for o in output_shape], expected_output_shape
        )

    def test_stacked_rnn_with_training_param(self):
        # See https://github.com/tensorflow/tensorflow/issues/32586

        class CellWrapper(keras.layers.AbstractRNNCell):
            def __init__(self, cell):
                super().__init__()
                self.cell = cell

            @property
            def state_size(self):
                return self.cell.state_size

            @property
            def output_size(self):
                return self.cell.output_size

            def build(self, input_shape):
                self.cell.build(input_shape)
                self.built = True

            def get_initial_state(
                self, inputs=None, batch_size=None, dtype=None
            ):
                return self.cell.get_initial_state(
                    inputs=inputs, batch_size=batch_size, dtype=dtype
                )

            def call(self, inputs, states, training=None, **kwargs):
                assert training is not None
                return self.cell(inputs, states=states, training=training)

        cell = keras.layers.LSTMCell(32)
        cell = CellWrapper(cell)
        cell = keras.layers.StackedRNNCells([cell])

        rnn = keras.layers.RNN(cell)
        inputs = np.ones((8, 4, 16), dtype=np.float32)
        rnn(inputs, training=True)

    def test_stacked_rnn_with_nested_cell(self):
        batch = 10
        t = 5
        i1, i2, i3 = 3, 4, 5
        o11, o12, o13 = 2, 3, 4
        o21, o22, o23 = 4, 5, 6

        # test 1: use_tuple=False
        cells = [NestedCell(o11, o12, o13), NestedCell(o21, o22, o23)]
        rnn = keras.layers.RNN(cells, return_sequences=True, return_state=True)

        input_1 = keras.Input((t, i1))
        input_2 = keras.Input((t, i2, i3))

        output1, output2, state1, state2 = rnn((input_1, input_2))
        s11, s12 = state1
        s21, s22 = state2

        self.assertEqual(output1.shape.as_list(), [None, t, o21])
        self.assertEqual(output2.shape.as_list(), [None, t, o22, o23])
        self.assertEqual(s11.shape.as_list(), [None, o11])
        self.assertEqual(s12.shape.as_list(), [None, o12, o13])
        self.assertEqual(s21.shape.as_list(), [None, o21])
        self.assertEqual(s22.shape.as_list(), [None, o22, o23])

        model = keras.models.Model([input_1, input_2], [output1, output2])
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            [np.zeros((batch, t, i1)), np.zeros((batch, t, i2, i3))],
            [np.zeros((batch, t, o21)), np.zeros((batch, t, o22, o23))],
        )
        self.assertEqual(
            model.output_shape, [(None, t, o21), (None, t, o22, o23)]
        )

        # test 2: use_tuple=True
        cells = [
            NestedCell(o11, o12, o13, use_tuple=True),
            NestedCell(o21, o22, o23),
        ]

        rnn = keras.layers.RNN(cells, return_sequences=True, return_state=True)

        input_1 = keras.Input((t, i1))
        input_2 = keras.Input((t, i2, i3))

        output1, output2, state1, state2 = rnn(
            NestedInput(t1=input_1, t2=input_2)
        )
        s11, s12 = state1
        s21, s22 = state2

        self.assertEqual(output1.shape.as_list(), [None, t, o21])
        self.assertEqual(output2.shape.as_list(), [None, t, o22, o23])
        self.assertEqual(s11.shape.as_list(), [None, o11])
        self.assertEqual(s12.shape.as_list(), [None, o12, o13])
        self.assertEqual(s21.shape.as_list(), [None, o21])
        self.assertEqual(s22.shape.as_list(), [None, o22, o23])

        model = keras.models.Model([input_1, input_2], [output1, output2])
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            [np.zeros((batch, t, i1)), np.zeros((batch, t, i2, i3))],
            [np.zeros((batch, t, o21)), np.zeros((batch, t, o22, o23))],
        )
        self.assertEqual(
            model.output_shape, [(None, t, o21), (None, t, o22, o23)]
        )

    def test_trackable_dependencies(self):
        rnn = keras.layers.SimpleRNN
        x = np.random.random((2, 2, 2))
        y = np.random.random((2, 2))
        model = keras.models.Sequential()
        model.add(rnn(2))
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.fit(x, y, epochs=1, batch_size=1)

        # Check whether the model variables are present in the list of
        # trackable objects.
        checkpointed_objects = {
            id(o) for o in trackable_util.list_objects(model)
        }
        for v in model.variables:
            self.assertIn(id(v), checkpointed_objects)

    def test_high_dimension_RNN(self):
        # Basic test case.
        unit_a = 10
        unit_b = 20
        input_a = 5
        input_b = 10
        batch = 32
        time_step = 4

        cell = Minimal2DRNNCell(unit_a, unit_b)
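        # Minimal2DRNNCell keeps a rank-2 per-sample state of shape
        # [unit_a, unit_b]; the RNN layer must handle states that are not
        # flat vectors.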
        x = keras.Input((None, input_a, input_b))
        layer = keras.layers.RNN(cell)
        y = layer(x)

        self.assertEqual(cell.state_size.as_list(), [unit_a, unit_b])

        if not tf.executing_eagerly():
            init_state = layer.get_initial_state(x)
            self.assertEqual(len(init_state), 1)
            self.assertEqual(
                init_state[0].shape.as_list(), [None, unit_a, unit_b]
            )

        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            np.zeros((batch, time_step, input_a, input_b)),
            np.zeros((batch, unit_a, unit_b)),
        )
        self.assertEqual(model.output_shape, (None, unit_a, unit_b))

        # Test stacking.
        cells = [
            Minimal2DRNNCell(unit_a, unit_b),
            Minimal2DRNNCell(unit_a * 2, unit_b * 2),
            Minimal2DRNNCell(unit_a * 4, unit_b * 4),
        ]
        layer = keras.layers.RNN(cells)
        y = layer(x)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            np.zeros((batch, time_step, input_a, input_b)),
            np.zeros((batch, unit_a * 4, unit_b * 4)),
        )
        self.assertEqual(model.output_shape, (None, unit_a * 4, unit_b * 4))

    def test_high_dimension_RNN_with_init_state(self):
        unit_a = 10
        unit_b = 20
        input_a = 5
        input_b = 10
        batch = 32
        time_step = 4

        # Basic test case.
        cell = Minimal2DRNNCell(unit_a, unit_b)
        x = keras.Input((None, input_a, input_b))
        s = keras.Input((unit_a, unit_b))
        layer = keras.layers.RNN(cell)
        y = layer(x, initial_state=s)

        model = keras.models.Model([x, s], y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            [
                np.zeros((batch, time_step, input_a, input_b)),
                np.zeros((batch, unit_a, unit_b)),
            ],
            np.zeros((batch, unit_a, unit_b)),
        )
        self.assertEqual(model.output_shape, (None, unit_a, unit_b))

        # Bad init state shape.
        bad_shape_a = unit_a * 2
        bad_shape_b = unit_b * 2
        cell = Minimal2DRNNCell(unit_a, unit_b)
        x = keras.Input((None, input_a, input_b))
        s = keras.Input((bad_shape_a, bad_shape_b))
        layer = keras.layers.RNN(cell)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "however `cell.state_size` is"
        ):
            layer(x, initial_state=s)

    def test_inconsistent_output_state_size(self):
        batch = 32
        time_step = 4
        state_size = 5
        input_size = 6
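        # PlusOneRNNCell declares state_size=5 but emits outputs of width
        # input_size, so the layer must take its output shape from the actual
        # cell output rather than from state_size.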
        cell = PlusOneRNNCell(state_size)
        x = keras.Input((None, input_size))
        layer = keras.layers.RNN(cell)
        y = layer(x)

        self.assertEqual(cell.state_size, state_size)
        if not tf.executing_eagerly():
            init_state = layer.get_initial_state(x)
            self.assertEqual(len(init_state), 1)
            self.assertEqual(init_state[0].shape.as_list(), [None, state_size])

        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            np.zeros((batch, time_step, input_size)),
            np.zeros((batch, input_size)),
        )
        self.assertEqual(model.output_shape, (None, input_size))

    def test_get_initial_state(self):
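        # get_initial_state derives the batch size and dtype from `inputs`
        # when it is provided; otherwise both must be passed explicitly.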
        cell = keras.layers.SimpleRNNCell(5)
        with self.assertRaisesRegex(
            ValueError, "batch_size and dtype cannot be None"
        ):
            cell.get_initial_state(None, None, None)

        if not tf.executing_eagerly():
            inputs = keras.Input((None, 10))
            initial_state = cell.get_initial_state(inputs, None, None)
            self.assertEqual(initial_state.shape.as_list(), [None, 5])
            self.assertEqual(initial_state.dtype, inputs.dtype)

            batch = tf.shape(inputs)[0]
            dtype = inputs.dtype
            initial_state = cell.get_initial_state(None, batch, dtype)
            self.assertEqual(initial_state.shape.as_list(), [None, 5])
            self.assertEqual(initial_state.dtype, inputs.dtype)
        else:
            batch = 8
            inputs = np.random.random((batch, 10))
            initial_state = cell.get_initial_state(inputs, None, None)
            self.assertEqual(initial_state.shape.as_list(), [8, 5])
            self.assertEqual(initial_state.dtype, inputs.dtype)

            dtype = inputs.dtype
            initial_state = cell.get_initial_state(None, batch, dtype)
            self.assertEqual(initial_state.shape.as_list(), [batch, 5])
            self.assertEqual(initial_state.dtype, inputs.dtype)

    @parameterized.parameters([True, False])
    def test_nested_input_output(self, stateful):
        batch = 10
        t = 5
        i1, i2, i3 = 3, 4, 5
        o1, o2, o3 = 2, 3, 4

        cell = NestedCell(o1, o2, o3)
        rnn = keras.layers.RNN(cell, stateful=stateful)

        batch_size = batch if stateful else None
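        # Stateful RNNs require a fixed batch size so that the state
        # variables can be allocated up front.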
        input_1 = keras.Input((t, i1), batch_size=batch_size)
        input_2 = keras.Input((t, i2, i3), batch_size=batch_size)

        outputs = rnn((input_1, input_2))

        self.assertEqual(len(outputs), 2)
        self.assertEqual(outputs[0].shape.as_list(), [batch_size, o1])
        self.assertEqual(outputs[1].shape.as_list(), [batch_size, o2, o3])

        model = keras.models.Model((input_1, input_2), outputs)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            [np.zeros((batch, t, i1)), np.zeros((batch, t, i2, i3))],
            [np.zeros((batch, o1)), np.zeros((batch, o2, o3))],
        )
        self.assertEqual(
            model.output_shape, [(batch_size, o1), (batch_size, o2, o3)]
        )

        cell = NestedCell(o1, o2, o3, use_tuple=True)

        rnn = keras.layers.RNN(cell, stateful=stateful)

        input_1 = keras.Input((t, i1), batch_size=batch_size)
        input_2 = keras.Input((t, i2, i3), batch_size=batch_size)

        outputs = rnn(NestedInput(t1=input_1, t2=input_2))

        self.assertEqual(len(outputs), 2)
        self.assertEqual(outputs[0].shape.as_list(), [batch_size, o1])
        self.assertEqual(outputs[1].shape.as_list(), [batch_size, o2, o3])

        model = keras.models.Model([input_1, input_2], outputs)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            [np.zeros((batch, t, i1)), np.zeros((batch, t, i2, i3))],
            [np.zeros((batch, o1)), np.zeros((batch, o2, o3))],
        )
        self.assertEqual(
            model.output_shape, [(batch_size, o1), (batch_size, o2, o3)]
        )

    def test_nested_input_output_with_state(self):
        batch = 10
        t = 5
        i1, i2, i3 = 3, 4, 5
        o1, o2, o3 = 2, 3, 4

        cell = NestedCell(o1, o2, o3)
        rnn = keras.layers.RNN(cell, return_sequences=True, return_state=True)

        input_1 = keras.Input((t, i1))
        input_2 = keras.Input((t, i2, i3))

        output1, output2, s1, s2 = rnn((input_1, input_2))

        self.assertEqual(output1.shape.as_list(), [None, t, o1])
        self.assertEqual(output2.shape.as_list(), [None, t, o2, o3])
        self.assertEqual(s1.shape.as_list(), [None, o1])
        self.assertEqual(s2.shape.as_list(), [None, o2, o3])

        model = keras.models.Model([input_1, input_2], [output1, output2])
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            [np.zeros((batch, t, i1)), np.zeros((batch, t, i2, i3))],
            [np.zeros((batch, t, o1)), np.zeros((batch, t, o2, o3))],
        )
        self.assertEqual(model.output_shape, [(None, t, o1), (None, t, o2, o3)])

        cell = NestedCell(o1, o2, o3, use_tuple=True)

        rnn = keras.layers.RNN(cell, return_sequences=True, return_state=True)

        input_1 = keras.Input((t, i1))
        input_2 = keras.Input((t, i2, i3))

        output1, output2, s1, s2 = rnn(NestedInput(t1=input_1, t2=input_2))

        self.assertEqual(output1.shape.as_list(), [None, t, o1])
        self.assertEqual(output2.shape.as_list(), [None, t, o2, o3])
        self.assertEqual(s1.shape.as_list(), [None, o1])
        self.assertEqual(s2.shape.as_list(), [None, o2, o3])

        model = keras.models.Model([input_1, input_2], [output1, output2])
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            [np.zeros((batch, t, i1)), np.zeros((batch, t, i2, i3))],
            [np.zeros((batch, t, o1)), np.zeros((batch, t, o2, o3))],
        )
        self.assertEqual(model.output_shape, [(None, t, o1), (None, t, o2, o3)])

    def test_nest_input_output_with_init_state(self):
        batch = 10
        t = 5
        i1, i2, i3 = 3, 4, 5
        o1, o2, o3 = 2, 3, 4

        cell = NestedCell(o1, o2, o3)
        rnn = keras.layers.RNN(cell, return_sequences=True, return_state=True)

        input_1 = keras.Input((t, i1))
        input_2 = keras.Input((t, i2, i3))
        init_s1 = keras.Input((o1,))
        init_s2 = keras.Input((o2, o3))

        output1, output2, s1, s2 = rnn(
            (input_1, input_2), initial_state=(init_s1, init_s2)
        )

        self.assertEqual(output1.shape.as_list(), [None, t, o1])
        self.assertEqual(output2.shape.as_list(), [None, t, o2, o3])
        self.assertEqual(s1.shape.as_list(), [None, o1])
        self.assertEqual(s2.shape.as_list(), [None, o2, o3])

        model = keras.models.Model(
            [input_1, input_2, init_s1, init_s2], [output1, output2]
        )
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            [
                np.zeros((batch, t, i1)),
                np.zeros((batch, t, i2, i3)),
                np.zeros((batch, o1)),
                np.zeros((batch, o2, o3)),
            ],
            [np.zeros((batch, t, o1)), np.zeros((batch, t, o2, o3))],
        )
        self.assertEqual(model.output_shape, [(None, t, o1), (None, t, o2, o3)])

        cell = NestedCell(o1, o2, o3, use_tuple=True)

        rnn = keras.layers.RNN(cell, return_sequences=True, return_state=True)

        input_1 = keras.Input((t, i1))
        input_2 = keras.Input((t, i2, i3))
        init_s1 = keras.Input((o1,))
        init_s2 = keras.Input((o2, o3))
        init_state = NestedState(s1=init_s1, s2=init_s2)

        output1, output2, s1, s2 = rnn(
            NestedInput(t1=input_1, t2=input_2), initial_state=init_state
        )

        self.assertEqual(output1.shape.as_list(), [None, t, o1])
        self.assertEqual(output2.shape.as_list(), [None, t, o2, o3])
        self.assertEqual(s1.shape.as_list(), [None, o1])
        self.assertEqual(s2.shape.as_list(), [None, o2, o3])

        model = keras.models.Model(
            [input_1, input_2, init_s1, init_s2], [output1, output2]
        )
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            [
                np.zeros((batch, t, i1)),
                np.zeros((batch, t, i2, i3)),
                np.zeros((batch, o1)),
                np.zeros((batch, o2, o3)),
            ],
            [np.zeros((batch, t, o1)), np.zeros((batch, t, o2, o3))],
        )
        self.assertEqual(model.output_shape, [(None, t, o1), (None, t, o2, o3)])

    def test_masking_rnn_with_output_and_states(self):
        class Cell(keras.layers.Layer):
            def __init__(self):
                self.state_size = None
                self.output_size = None
                super().__init__()

            def build(self, input_shape):
                self.state_size = input_shape[-1]
                self.output_size = input_shape[-1]

            def call(self, inputs, states):
                return inputs, [s + 1 for s in states]

        x = keras.Input((3, 1), name="x")
        x_masked = keras.layers.Masking()(x)
        s_0 = keras.Input((1,), name="s_0")
        y, s = keras.layers.RNN(Cell(), return_state=True)(
            x_masked, initial_state=s_0
        )
        model = keras.models.Model([x, s_0], [y, s])
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )

        # The last timestep is masked: its input is all zeros.
        x_np = np.array([[[1.0], [2.0], [0.0]]])
        s_0_np = np.array([[10.0]])
        y_np, s_np = model.predict([x_np, s_0_np])

        # 1 is added to the initial state twice, once per unmasked timestep.
        self.assertAllClose(s_np, s_0_np + 2)
        # Expect the final output to equal the last output before masking,
        # i.e. the input at the last unmasked timestep.
        self.assertAllClose(y_np, x_np[:, 1, :])

    def test_zero_output_for_masking(self):
        for unroll in [True, False]:
            cell = keras.layers.SimpleRNNCell(5)
            x = keras.Input((5, 5))
            mask = keras.layers.Masking()
            layer = keras.layers.RNN(
                cell,
                return_sequences=True,
                zero_output_for_mask=True,
                unroll=unroll,
            )
            masked_input = mask(x)
            y = layer(masked_input)
            model = keras.models.Model(x, y)
            model.compile(
                optimizer="rmsprop",
                loss="mse",
                run_eagerly=test_utils.should_run_eagerly(),
            )

            np_x = np.ones((6, 5, 5))
            result_1 = model.predict(np_x)

            # Set timesteps 4 and 5 of the last record to zero (masked).
            np_x[5, 3:] = 0
            result_2 = model.predict(np_x)

            # Expect result_2 to match result_1 everywhere except at timesteps
            # 4 and 5 of the last record, which the mask zeroes out.
            result_1[5, 3:] = 0
            self.assertAllClose(result_1, result_2)

    def test_unroll_single_step(self):
        """Even if the time dimension is only one, we should be able to
        unroll."""
        cell = keras.layers.SimpleRNNCell(5)
        x = keras.Input((1, 5))
        layer = keras.layers.RNN(cell, return_sequences=True, unroll=True)
        y = layer(x)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )

        np_x = np.ones((6, 1, 5))
        result = model.predict(np_x)
        self.assertEqual((6, 1, 5), result.shape)

    def test_unroll_zero_step(self):
        """If the time dimension is None, we should fail to unroll."""
        cell = keras.layers.SimpleRNNCell(5)
        x = keras.Input((None, 5))
        layer = keras.layers.RNN(cell, return_sequences=True, unroll=True)
        with self.assertRaisesRegex(ValueError, "Cannot unroll a RNN.*"):
            layer(x)

    def test_full_input_spec(self):
        # See https://github.com/tensorflow/tensorflow/issues/25985
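        # A stateful LSTM driven through explicit initial-state inputs:
        # reset_states() should restore the internal state so that output3
        # matches output1, while output2 (computed from carried-over state)
        # differs.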
        inputs = keras.layers.Input(batch_shape=(1, 1, 1))
        state_h = keras.layers.Input(batch_shape=(1, 1))
        state_c = keras.layers.Input(batch_shape=(1, 1))
        states = [state_h, state_c]
        decoder_out = keras.layers.LSTM(1, stateful=True)(
            inputs, initial_state=states
        )
        model = keras.Model([inputs, state_h, state_c], decoder_out)
        output1 = model.predict(
            [np.ones((1, 1, 1)), np.ones((1, 1)), np.ones((1, 1))]
        )
        output2 = model.predict(
            [np.ones((1, 1, 1)), np.ones((1, 1)), np.ones((1, 1))]
        )
        model.reset_states()
        output3 = model.predict(
            [np.ones((1, 1, 1)), np.ones((1, 1)), np.ones((1, 1))]
        )
        self.assertAllClose(output1, output3)
        self.assertNotAllClose(output1, output2)

    def test_reset_states(self):
        # See https://github.com/tensorflow/tensorflow/issues/25852
        with self.assertRaisesRegex(
            ValueError, "it needs to know its batch size"
        ):
            simple_rnn = keras.layers.SimpleRNN(1, stateful=True)
            simple_rnn.reset_states()

        with self.assertRaisesRegex(
            ValueError, "it needs to know its batch size"
        ):
            cell = Minimal2DRNNCell(1, 2)
            custom_rnn = keras.layers.RNN(cell, stateful=True)
            custom_rnn.reset_states()

    @parameterized.parameters(
        [
            keras.layers.SimpleRNNCell,
            keras.layers.GRUCell,
            keras.layers.LSTMCell,
        ]
    )
    def test_stateful_rnn_with_stacking(self, cell):
        # See https://github.com/tensorflow/tensorflow/issues/28614.
        batch = 12
        timesteps = 10
        input_dim = 8
        output_dim = 64
        cells = [cell(32), cell(64)]
        x = keras.Input(batch_shape=(batch, None, input_dim))
        layer = keras.layers.RNN(cells, stateful=True)
        y = layer(x)

        model = keras.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            np.zeros((batch, timesteps, input_dim)),
            np.zeros((batch, output_dim)),
        )
        model.predict(np.ones((batch, timesteps, input_dim)))

        model.reset_states()
        model.predict(np.ones((batch, timesteps, input_dim)))

        new_states = tf.nest.map_structure(
            lambda s: np.ones((batch, s)), layer.cell.state_size
        )
        layer.reset_states(new_states)
        model.predict(np.ones((batch, timesteps, input_dim)))
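
        # For stacked cells, layer.cell is a StackedRNNCells whose state_size
        # nests the per-cell sizes (roughly [32, 64] for SimpleRNNCells, with
        # two entries per cell for LSTMCells), so map_structure above builds
        # a matching nested structure of numpy arrays for reset_states().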

    def test_stateful_rnn_with_initial_state(self):
        # See https://github.com/tensorflow/tensorflow/issues/32299.
        batch = 12
        timesteps = 1
        input_dim = 8
        output_dim = 16

        test_inputs = np.full((batch, timesteps, input_dim), 0.5)

        def make_model(stateful=False, with_initial_state=False):
            input_layer = keras.Input(shape=(None, input_dim), batch_size=batch)
            if with_initial_state:
                initial_states = keras.backend.constant(
                    np.ones((batch, output_dim))
                )
            else:
                initial_states = None
            rnn_output = keras.layers.GRU(
                units=output_dim, return_sequences=True, stateful=stateful
            )(input_layer, initial_state=initial_states)
            model = keras.Model(input_layer, rnn_output)
            model.compile(
                optimizer="rmsprop",
                loss="mse",
                run_eagerly=test_utils.should_run_eagerly(),
            )
            return model

        # Define a model with a constant state initialization
        model = make_model(stateful=True, with_initial_state=True)
        layer_weights = model.layers[1].get_weights()

        model.reset_states()
        predict_1 = model.predict(test_inputs)
        predict_2 = model.predict(test_inputs)

        model.reset_states()
        predict_3 = model.predict(test_inputs)

        # predict_1 and predict_2 should differ since the second call uses
        # the state carried over from the first call as its initial state.
        self.assertNotAllClose(predict_1, predict_2)
        self.assertAllClose(predict_1, predict_3)

        # Create a new model with the same weights but without an initial
        # state. Make sure its predictions differ from those of the model
        # with the non-zero initial state.
        model_2 = make_model(stateful=True, with_initial_state=False)
        model_2.layers[1].set_weights(layer_weights)

        model_2.reset_states()
        predict_4 = model_2.predict(test_inputs)
        predict_5 = model_2.predict(test_inputs)
        self.assertNotAllClose(predict_1, predict_4)
        self.assertNotAllClose(predict_4, predict_5)

        # Create a model with stateful=False, and make sure it handles the
        # initial state correctly.
        model_3 = make_model(stateful=False, with_initial_state=True)
        model_3.layers[1].set_weights(layer_weights)

        model_3.reset_states()
        predict_6 = model_3.predict(test_inputs)
        predict_7 = model_3.predict(test_inputs)
        self.assertAllClose(predict_1, predict_6)
        self.assertAllClose(predict_6, predict_7)

    def test_stateful_rnn_with_customized_get_initial_state(self):
        class TestCell(keras.layers.AbstractRNNCell):

            state_size = 1
            output_size = 2

            def get_initial_state(
                self, inputs=None, batch_size=None, dtype=None
            ):
                return np.ones((batch_size, 1), dtype=dtype)

            def call(self, inputs, states):
                return inputs, states

        layer = keras.layers.RNN(TestCell(), stateful=True, return_state=True)
        inputs = keras.Input(shape=(10, 2), batch_size=4)
        model = keras.Model(inputs, layer(inputs))
        x = np.ones((4, 10, 2), dtype=np.float32)
        output, state = model.predict(x)
        self.assertAllClose(output, np.ones((4, 2)))
        self.assertAllClose(state, np.ones((4, 1)))

    def test_input_dim_length(self):
        simple_rnn = keras.layers.SimpleRNN(5, input_length=10, input_dim=8)
        self.assertEqual(simple_rnn._batch_input_shape, (None, 10, 8))

        simple_rnn = keras.layers.SimpleRNN(5, input_dim=8)
        self.assertEqual(simple_rnn._batch_input_shape, (None, None, 8))

        simple_rnn = keras.layers.SimpleRNN(5, input_length=10)
        self.assertEqual(simple_rnn._batch_input_shape, (None, 10, None))
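
        # For context: input_dim/input_length are legacy kwargs; the cases
        # above are assumed equivalent to spelling out input_shape, e.g.
        #
        #   keras.layers.SimpleRNN(5, input_shape=(10, 8))
        #   # -> _batch_input_shape == (None, 10, 8)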

    @parameterized.parameters(
        [
            keras.layers.SimpleRNNCell,
            keras.layers.GRUCell,
            keras.layers.LSTMCell,
        ]
    )
    def test_state_spec_with_stack_cell(self, cell):
        # See https://github.com/tensorflow/tensorflow/issues/27817 for more
        # detail.
        batch = 12
        timesteps = 10
        input_dim = 8
        output_dim = 8

        def create_cell():
            return [cell(output_dim), cell(output_dim), cell(output_dim)]

        inputs = keras.Input((timesteps, input_dim))
        encoder_output = keras.layers.RNN(create_cell(), return_state=True)(
            inputs
        )

        states = encoder_output[1:]

        decoder_output = keras.layers.RNN(create_cell())(
            inputs, initial_state=states
        )

        model = keras.models.Model(inputs, decoder_output)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(
            np.zeros((batch, timesteps, input_dim)),
            np.zeros((batch, output_dim)),
        )
        model.predict(np.ones((batch, timesteps, input_dim)))

    @parameterized.named_parameters(
        *test_utils.generate_combinations_with_testcase_name(
            layer=[
                keras.layers.SimpleRNN,
                gru_v1.GRU,
                lstm_v1.LSTM,
                gru.GRU,
                lstm.LSTM,
            ]
        )
    )
    def test_rnn_with_ragged_input(self, layer):
        ragged_data = tf.ragged.constant(
            [
                [[1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 2.0, 3.0, 1.0, 1.0]],
                [[2.0, 4.0, 1.0, 3.0, 1.0]],
                [
                    [2.0, 3.0, 4.0, 1.0, 5.0],
                    [2.0, 3.0, 1.0, 1.0, 1.0],
                    [1.0, 2.0, 3.0, 4.0, 5.0],
                ],
            ],
            ragged_rank=1,
        )
        label_data = np.array([[1, 0, 1], [1, 1, 0], [0, 0, 1]])

        # Test feed-forward results: ragged input vs. masked dense input.
        np.random.seed(100)
        rnn_layer = layer(4, activation="sigmoid")

        x_ragged = keras.Input(shape=(None, 5), ragged=True)
        y_ragged = rnn_layer(x_ragged)
        model = keras.models.Model(x_ragged, y_ragged)
        output_ragged = model.predict(ragged_data, steps=1)

        x_dense = keras.Input(shape=(3, 5))
        masking = keras.layers.Masking()(x_dense)
        y_dense = rnn_layer(masking)
        model_2 = keras.models.Model(x_dense, y_dense)
        dense_data = ragged_data.to_tensor()
        output_dense = model_2.predict(dense_data, steps=1)

        self.assertAllClose(output_dense, output_ragged)

        # Test results with go_backwards=True.
        np.random.seed(200)
        back_rnn_layer = layer(8, go_backwards=True, activation="sigmoid")

        x_ragged = keras.Input(shape=(None, 5), ragged=True)
        y_ragged = back_rnn_layer(x_ragged)
        model = keras.models.Model(x_ragged, y_ragged)
        output_ragged = model.predict(ragged_data, steps=1)

        x_dense = keras.Input(shape=(3, 5))
        masking = keras.layers.Masking()(x_dense)
        y_dense = back_rnn_layer(masking)
        model_2 = keras.models.Model(x_dense, y_dense)
        dense_data = ragged_data.to_tensor()
        output_dense = model_2.predict(dense_data, steps=1)

        self.assertAllClose(output_dense, output_ragged)

        # Test densification of the ragged input
        dense_tensor, row_lengths = keras.backend.convert_inputs_if_ragged(
            ragged_data
        )
        self.assertAllClose(dense_data, dense_tensor)
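        # row_lengths above is expected to be [2, 1, 3] for this ragged_data;
        # it is reused below to re-raggedify dense outputs for comparison.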

        # Test optional params; everything should work except unrolling.
        inputs = keras.Input(shape=(None, 5), dtype=tf.float32, ragged=True)
        custom_rnn_layer = layer(
            3, zero_output_for_mask=True, dropout=0.1, use_bias=True
        )
        outputs = custom_rnn_layer(inputs)
        model = keras.models.Model(inputs, outputs)
        model.compile(
            optimizer="sgd",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(ragged_data, label_data)

        # Test stateful and full shape specification
        inputs = keras.Input(
            shape=(None, 5), batch_size=3, dtype=tf.float32, ragged=True
        )
        stateful_rnn_layer = layer(3, stateful=True)
        outputs = stateful_rnn_layer(inputs)
        model = keras.models.Model(inputs, outputs)
        model.compile(
            optimizer="sgd",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(ragged_data, label_data)

        # Must raise an error when unroll is set to True.
        unroll_rnn_layer = layer(3, unroll=True)
        with self.assertRaisesRegex(
            ValueError, "The input received contains RaggedTensors *"
        ):
            unroll_rnn_layer(inputs)

        # Check that return_sequences outputs are correct.
        np.random.seed(100)
        returning_rnn_layer = layer(4, return_sequences=True)

        x_ragged = keras.Input(shape=(None, 5), ragged=True)
        y_ragged = returning_rnn_layer(x_ragged)
        model = keras.models.Model(x_ragged, y_ragged)
        output_ragged = model.predict(ragged_data, steps=1)
        self.assertAllClose(output_ragged.ragged_rank, ragged_data.ragged_rank)
        self.assertAllClose(output_ragged.row_splits, ragged_data.row_splits)

        x_dense = keras.Input(shape=(3, 5))
        masking = keras.layers.Masking()(x_dense)
        y_dense = returning_rnn_layer(masking)
        model_2 = keras.models.Model(x_dense, y_dense)
        dense_data = ragged_data.to_tensor()
        output_dense = model_2.predict(dense_data, steps=1)
        # Convert the output here to ragged for value comparison
        output_dense = tf.RaggedTensor.from_tensor(
            output_dense, lengths=row_lengths
        )
        self.assertAllClose(output_ragged, output_dense)

        # Check that return_sequences and go_backwards outputs are correct.
        np.random.seed(100)
        returning_rnn_layer = layer(4, go_backwards=True, return_sequences=True)

        x_ragged = keras.Input(shape=(None, 5), ragged=True)
        y_ragged = returning_rnn_layer(x_ragged)
        model = keras.models.Model(x_ragged, y_ragged)
        output_ragged = model.predict(ragged_data, steps=1)
        self.assertAllClose(output_ragged.ragged_rank, ragged_data.ragged_rank)
        self.assertAllClose(output_ragged.row_splits, ragged_data.row_splits)

        x_dense = keras.Input(shape=(3, 5))
        masking = keras.layers.Masking()(x_dense)
        y_dense = returning_rnn_layer(masking)
        model_2 = keras.models.Model(x_dense, y_dense)
        dense_data = ragged_data.to_tensor()
        output_dense = model_2.predict(dense_data, steps=1)

        # Note that the raw outputs for dense and ragged input differ when
        # go_backwards=True. Consider the input [[a, b, 0], [c, 0, 0],
        # [d, e, f]], where the 0s are masked values. The dense output will
        # be [[0, b, a], [0, 0, c], [f, e, d]], since the whole sequence is
        # processed from the end, while the ragged output will be
        # [[b, a], [c], [f, e, d]], since the 0s are simply ignored. When the
        # ragged output is densified, 0s are by default padded at the end
        # (rather than the beginning), giving [[b, a, 0], [c, 0, 0],
        # [f, e, d]]. We therefore verify that the time-reversed ragged
        # output matches the re-raggedified, time-reversed dense output.
        output_dense = keras.backend.reverse(output_dense, [1])
        output_dense = tf.RaggedTensor.from_tensor(
            output_dense, lengths=row_lengths
        )

        self.assertAllClose(
            keras.backend.reverse(output_ragged, [1]), output_dense
        )

    def test_stateless_rnn_cell(self):
        class StatelessCell(keras.layers.Layer):
            def __init__(self):
                self.state_size = ((), [], ())
                self.output_size = None
                super().__init__()

            def build(self, input_shape):
                self.output_size = input_shape[-1]

            def call(self, inputs, states):
                return inputs, states

        x = keras.Input((None, 5))
        cell = StatelessCell()
        initial_state = tf.nest.map_structure(lambda t: None, cell.state_size)
        layer = keras.layers.RNN(cell)
        y = layer(x, initial_state=initial_state)
        model = keras.models.Model(x, y)
        model.compile(
            optimizer="rmsprop",
            loss="mse",
            run_eagerly=test_utils.should_run_eagerly(),
        )
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 5)))

    @parameterized.parameters(
        [keras.layers.SimpleRNN, gru_v1.GRU, lstm_v1.LSTM, gru.GRU, lstm.LSTM]
    )
    def test_for_enable_caching_device_for_layer(self, layer_cls):
        expected_caching_device = (
            tf.compat.v1.executing_eagerly_outside_functions()
        )
        layer = layer_cls(1)
        self.assertEqual(
            layer.cell._enable_caching_device, expected_caching_device
        )

        # Make sure the config entry only appears when a non-default value
        # is used.
        config = layer.get_config()
        self.assertNotIn("enable_caching_device", config)

        non_default_value = not expected_caching_device
        layer = layer_cls(1, enable_caching_device=non_default_value)
        self.assertEqual(layer.cell._enable_caching_device, non_default_value)
        config = layer.get_config()
        self.assertEqual(config["enable_caching_device"], non_default_value)

    @parameterized.parameters(
        [
            keras.layers.SimpleRNNCell,
            gru_v1.GRUCell,
            lstm_v1.LSTMCell,
            gru.GRUCell,
            lstm.LSTMCell,
        ]
    )
    def test_for_enable_caching_device_for_cell(self, cell_cls):
        expected_caching_device = (
            tf.compat.v1.executing_eagerly_outside_functions()
        )
        cell = cell_cls(1)
        self.assertEqual(cell._enable_caching_device, expected_caching_device)

        # Make sure the config entry only appears when a non-default value
        # is used.
        config = cell.get_config()
        self.assertNotIn("enable_caching_device", config)

        non_default_value = not expected_caching_device
        cell = cell_cls(1, enable_caching_device=non_default_value)
        self.assertEqual(cell._enable_caching_device, non_default_value)
        config = cell.get_config()
        self.assertEqual(config["enable_caching_device"], non_default_value)
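
        # For context (assumed rationale, see the cell implementations): when
        # caching is enabled, the cell creates its kernels with a
        # caching_device so that variable reads are cached inside the RNN's
        # while_loop rather than re-read at every timestep.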


class RNNCellWithConstants(keras.layers.Layer):
    def __init__(self, units, constant_size, **kwargs):
        self.units = units
        self.state_size = units
        self.constant_size = constant_size
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.input_kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="uniform",
            name="kernel",
        )
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer="uniform",
            name="recurrent_kernel",
        )
        self.constant_kernel = self.add_weight(
            shape=(self.constant_size, self.units),
            initializer="uniform",
            name="constant_kernel",
        )
        self.built = True

    def call(self, inputs, states, constants):
        [prev_output] = states
        [constant] = constants
        h_input = keras.backend.dot(inputs, self.input_kernel)
        h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
        h_const = keras.backend.dot(constant, self.constant_kernel)
        output = h_input + h_state + h_const
        return output, [output]

    def get_config(self):
        config = {"units": self.units, "constant_size": self.constant_size}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
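
# A hedged usage sketch for RNNCellWithConstants: tensors passed to the RNN
# layer via the `constants` call argument are forwarded to the cell's call(),
# e.g.
#
#   cell = RNNCellWithConstants(units=32, constant_size=3)
#   x = keras.Input((None, 5))
#   c = keras.Input((3,))
#   y = keras.layers.RNN(cell)(x, constants=c)  # y: (batch, 32)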


class Minimal2DRNNCell(keras.layers.Layer):
    """The minimal 2D RNN cell is a simple combination of 2 1-D RNN cell.

    Both internal state and output have 2 dimensions and are orthogonal
    between each other.
    """

    def __init__(self, unit_a, unit_b, **kwargs):
        self.unit_a = unit_a
        self.unit_b = unit_b
        self.state_size = tf.TensorShape([unit_a, unit_b])
        self.output_size = tf.TensorShape([unit_a, unit_b])
        super().__init__(**kwargs)

    def build(self, input_shape):
        input_a = input_shape[-2]
        input_b = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(input_a, input_b, self.unit_a, self.unit_b),
            initializer="uniform",
            name="kernel",
        )
        self.recurring_kernel = self.add_weight(
            shape=(self.unit_a, self.unit_b, self.unit_a, self.unit_b),
            initializer="uniform",
            name="recurring_kernel",
        )
        self.bias = self.add_weight(
            shape=(self.unit_a, self.unit_b), initializer="uniform", name="bias"
        )
        self.built = True

    def call(self, inputs, states):
        prev_output = states[0]
        h = tf.einsum("bij,ijkl->bkl", inputs, self.kernel)
        h += tf.expand_dims(self.bias, axis=0)
        output = h + tf.einsum(
            "bij,ijkl->bkl", prev_output, self.recurring_kernel
        )
        return output, [output]
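
# A hedged usage sketch for Minimal2DRNNCell: each timestep is a 2-D slice,
# so inputs have shape (batch, time, input_a, input_b), e.g.
#
#   cell = Minimal2DRNNCell(unit_a=4, unit_b=3)
#   x = keras.Input((None, 5, 6))  # input_a=5, input_b=6
#   y = keras.layers.RNN(cell)(x)  # y: (batch, 4, 3)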


class PlusOneRNNCell(keras.layers.Layer):
    """Add one to the input and state.

    This cell is used for testing state_size and output_size.
    """

    def __init__(self, num_unit, **kwargs):
        self.state_size = num_unit
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.output_size = input_shape[-1]

    def call(self, inputs, states):
        return inputs + 1, [states[0] + 1]


class NestedCell(keras.layers.Layer):
    def __init__(self, unit_1, unit_2, unit_3, use_tuple=False, **kwargs):
        self.unit_1 = unit_1
        self.unit_2 = unit_2
        self.unit_3 = unit_3
        self.use_tuple = use_tuple
        super().__init__(**kwargs)
        # A nested state.
        if use_tuple:
            self.state_size = NestedState(
                s1=unit_1, s2=tf.TensorShape([unit_2, unit_3])
            )
        else:
            self.state_size = (unit_1, tf.TensorShape([unit_2, unit_3]))
        self.output_size = (unit_1, tf.TensorShape([unit_2, unit_3]))

    def build(self, inputs_shape):
        # Expect inputs_shape to contain 2 items: [(batch, i1), (batch, i2, i3)].
        if self.use_tuple:
            input_1 = inputs_shape.t1[1]
            input_2, input_3 = inputs_shape.t2[1:]
        else:
            input_1 = inputs_shape[0][1]
            input_2, input_3 = inputs_shape[1][1:]

        self.kernel_1 = self.add_weight(
            shape=(input_1, self.unit_1), initializer="uniform", name="kernel_1"
        )
        self.kernel_2_3 = self.add_weight(
            shape=(input_2, input_3, self.unit_2, self.unit_3),
            initializer="uniform",
            name="kernel_2_3",
        )

    def call(self, inputs, states):
        # inputs: [(batch, input_1), (batch, input_2, input_3)].
        # states: [(batch, unit_1), (batch, unit_2, unit_3)].
        flatten_inputs = tf.nest.flatten(inputs)
        s1, s2 = states

        output_1 = tf.matmul(flatten_inputs[0], self.kernel_1)
        output_2_3 = tf.einsum(
            "bij,ijkl->bkl", flatten_inputs[1], self.kernel_2_3
        )
        state_1 = s1 + output_1
        state_2_3 = s2 + output_2_3

        output = [output_1, output_2_3]
        new_states = NestedState(s1=state_1, s2=state_2_3)

        return output, new_states
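
# A hedged usage sketch for NestedCell, using the NestedInput namedtuple
# defined at the top of this file:
#
#   cell = NestedCell(unit_1=4, unit_2=3, unit_3=2, use_tuple=True)
#   t1 = keras.Input((None, 6))
#   t2 = keras.Input((None, 5, 4))
#   out_1, out_2_3 = keras.layers.RNN(cell)(NestedInput(t1=t1, t2=t2))
#   # out_1: (batch, 4); out_2_3: (batch, 3, 2)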


if __name__ == "__main__":
    tf.test.main()
